import pandas as pd
import matplotlib.pyplot as plt

# Define the file generated by the C++ program
CSV_FILENAME = "regret_data_all_runs.csv"

# Load the wide-format data: an "Episode" column plus (expected) one column
# per (algorithm, run) pair, e.g. 'UCB-H_1', 'UCB-H_2', ...
#
# Failures abort via SystemExit(message): Python prints the message to stderr
# and exits with status 1. The bare exit() helper comes from the `site`
# module (it is missing under `python -S`) and exits with status 0, which
# would report success even though no plot was produced.
try:
    data = pd.read_csv(CSV_FILENAME)
except FileNotFoundError:
    raise SystemExit(
        f"Error: '{CSV_FILENAME}' not found. Please run the C++ program first."
    )
except pd.errors.EmptyDataError:
    # The file exists but contains no header or rows.
    raise SystemExit(f"Error: '{CSV_FILENAME}' is empty.")

# Every curve is plotted against this column, so fail early with a clear
# message instead of a KeyError midway through plotting.
if "Episode" not in data.columns:
    raise SystemExit(f"Error: '{CSV_FILENAME}' has no 'Episode' column.")

plt.figure(figsize=(12, 8))

# --- Algorithm table: key -> CSV column prefix, plot colour, legend label ---
# Rows are listed in plotting order; the dict preserves that insertion order.
algorithms = {
    key: {'prefix': prefix, 'color': color, 'label': label}
    for key, prefix, color, label in [
        # key      column prefix  colour    legend label
        ('AMB',    'AMB_',        'b',      'AMB'),
        ('QUL',    'QUL_',        'purple', 'ULCB-Hoeffding'),
        ('RAMB',   'RAMB_',       'g',      'Refined AMB'),
        ('UCB-H',  'UCB-H_',      'r',      'UCB-Hoeffding'),
    ]
}

# --- Plot every algorithm: median across runs with a shaded 10%-90% band ---
for algo_name, spec in algorithms.items():
    print(f"Processing {algo_name}...")

    # One column per run of this algorithm, e.g. 'UCB-H_1', 'UCB-H_2', ...
    prefix = spec['prefix']
    run_columns = [column for column in data.columns if column.startswith(prefix)]
    if not run_columns:
        print(f"Warning: No data columns found for algorithm '{algo_name}'. Skipping.")
        continue

    runs = data[run_columns]

    # Row-wise statistics (axis=1 aggregates across the run columns).
    centre = runs.median(axis=1)
    low, high = (runs.quantile(level, axis=1) for level in (0.1, 0.9))

    x = data["Episode"]
    colour = spec['color']
    # Median curve, then the translucent quantile band around it.
    plt.plot(x, centre, color=colour, label=spec['label'])
    plt.fill_between(x, low, high, color=colour, alpha=0.15, linewidth=0)


# --- Final plot styling (applied to the current Axes) ---
axes = plt.gca()
axes.legend(loc='best', fontsize=18)
axes.set_xlabel("Total number of episodes $K$", fontsize=18)
axes.set_ylabel("Regret$(T)$ / log$(K+1)$", fontsize=18)
for tick_setter in (plt.xticks, plt.yticks):
    tick_setter(fontsize=16)
plt.tight_layout()

# Write the figure to disk at high resolution, report it, then display it.
OUTPUT_FILENAME = "regret_result_with_quantiles.jpg"
plt.savefig(OUTPUT_FILENAME, dpi=600)
print(f"\nPlot saved to {OUTPUT_FILENAME}")
plt.show()